"""
Scenario analysis and projection module.

Implements:
1. Demographic projections using UN WPP medium-variant (2025-2060)
2. Counterfactual scenarios (China fertility, SSA institutions, Japan rates)
3. Residual decomposition
4. Country-level demographic contribution time series
"""

import pandas as pd
import numpy as np
from pathlib import Path

# Absolute locations of the project's processed-data and output trees.
# NOTE(review): both paths are hard-coded (they look like WSL mounts of a
# Windows drive) — confirm they exist before running on another machine.
# PROCESSED_DIR is not referenced by the functions visible in this module.
PROCESSED_DIR = Path("/mnt/c/demographics_capital_flows/multilateral/data/processed")
# Coefficient CSVs are read from, and result tables written to, OUTPUT_DIR / "tables".
OUTPUT_DIR = Path("/mnt/c/demographics_capital_flows/multilateral/output")


def project_demographic_contribution(model, demo_projections_df, feature_names,
                                     start_year=2025, end_year=2060):
    """
    Evaluate the demographic part of the fitted CA equation on projected data.

    For each projection row inside the horizon, the demographic contribution
    is the linear combination sum_k beta_k * Z_k taken over the features
    whose name starts with ``Z_``.

    Parameters
    ----------
    model : fitted model exposing ``beta``; it is indexed with a list of
            positions, so it must support that (e.g. a NumPy array), in the
            same order as ``feature_names``
    demo_projections_df : DataFrame with iso3, year and one column per
                          ``Z_*`` feature
    feature_names : list of feature names matching model.beta
    start_year, end_year : inclusive bounds of the projection horizon

    Returns
    -------
    DataFrame restricted to start_year..end_year with columns iso3, year,
    demo_contribution, followed by the ``Z_*`` columns.
    """
    demo_positions = [pos for pos, feature in enumerate(feature_names)
                      if feature.startswith('Z_')]
    demo_features = [feature_names[pos] for pos in demo_positions]
    demo_coefs = model.beta[demo_positions]

    year_col = demo_projections_df['year']
    in_horizon = (year_col >= start_year) & (year_col <= end_year)
    horizon = demo_projections_df[in_horizon].copy()

    # Coefficient-weighted sum of the polynomial demographic variables.
    horizon['demo_contribution'] = sum(
        coef * horizon[feature]
        for coef, feature in zip(demo_coefs, demo_features)
    )

    return horizon[['iso3', 'year', 'demo_contribution'] + demo_features]


def compute_country_profiles(model, panel_df, feature_names,
                             countries=None, years=None):
    """
    Build per-country demographic contribution profiles.

    For each requested country the panel rows are copied (optionally limited
    to an inclusive year window) and extended with:
    - ``demo_contribution``: sum over ``Z_*`` features of beta_k * Z_k
    - ``contribution_<Z>``: the individual beta_k * Z_k terms

    Parameters
    ----------
    model : fitted model exposing ``beta`` aligned with ``feature_names``
    panel_df : DataFrame with iso3, year and the ``Z_*`` columns
    feature_names : list of feature names matching model.beta
    countries : optional list of ISO3 codes; defaults to a fixed set of nine
                large economies
    years : optional (first_year, last_year) pair, both inclusive

    Returns
    -------
    (profiles, alpha) : ``profiles`` maps ISO3 code -> augmented DataFrame
    (countries with no rows are omitted); ``alpha`` is whatever
    ``recover_age_coefficients`` returns for the Z coefficients.
    """
    from .demographics import recover_age_coefficients

    if countries is None:
        countries = ['CHN', 'IND', 'IDN', 'JPN', 'USA', 'DEU', 'BRA', 'NGA', 'ZAF']

    demo_positions = [pos for pos, feature in enumerate(feature_names)
                      if feature.startswith('Z_')]
    demo_features = [feature_names[pos] for pos in demo_positions]
    demo_coefs = model.beta[demo_positions]

    # The Z coefficients are the polynomial parameters γ_1, γ_2, γ_3.
    alpha = recover_age_coefficients(demo_coefs)

    profiles = {}
    for iso in countries:
        country_panel = panel_df[panel_df['iso3'] == iso].copy()
        if years:
            after_start = country_panel['year'] >= years[0]
            before_end = country_panel['year'] <= years[1]
            country_panel = country_panel[after_start & before_end]

        if len(country_panel) == 0:
            continue

        # Individual terms first, so the total and the decomposition share them.
        terms = [(feature, coef * country_panel[feature])
                 for coef, feature in zip(demo_coefs, demo_features)]

        country_panel['demo_contribution'] = sum(term for _, term in terms)
        for feature, term in terms:
            country_panel[f'contribution_{feature}'] = term

        profiles[iso] = country_panel

    return profiles, alpha


def china_counterfactual(demo_shares_df, gdp_df=None):
    """
    Counterfactual: what if China's fertility had stayed at replacement
    (TFR=2.1)?

    Simplified construction: after 1980 (roughly the start of the one-child
    policy) China's age shares are blended with Indonesia's shares for the
    same year, Indonesia serving as a guide for a trajectory without the
    policy. The blend weight on Indonesia rises linearly from 0 in 1980 to
    0.5 in 2020 and stays at 0.5 afterwards; shares are then renormalized to
    sum to one.

    Parameters
    ----------
    demo_shares_df : DataFrame with iso3, year and share columns n_1..n_G
    gdp_df : accepted for interface compatibility; not used here

    Returns
    -------
    Copy of China's rows with iso3 set to 'CHN_cf' and blended shares, or
    None (after printing a message) when CHN or IDN rows are missing.
    """
    from .demographics import construct_polynomial_variables, G

    policy_year = 1980
    share_cols = [f'n_{g}' for g in range(1, G + 1)]

    china = demo_shares_df[demo_shares_df['iso3'] == 'CHN'].copy()
    indonesia = demo_shares_df[demo_shares_df['iso3'] == 'IDN'].copy()

    if len(china) == 0 or len(indonesia) == 0:
        print("Need both CHN and IDN data for counterfactual")
        return None

    counterfactual = china.copy()
    counterfactual['iso3'] = 'CHN_cf'

    for year in counterfactual['year'].unique():
        # Years up to and including the policy start keep the actual structure.
        if year <= policy_year:
            continue

        china_rows = china[china['year'] == year]
        indonesia_rows = indonesia[indonesia['year'] == year]
        if len(china_rows) == 0 or len(indonesia_rows) == 0:
            continue

        # Ramp reaches 1.0 forty years after the policy (2020); halving it
        # means the blend never moves fully to Indonesia's structure.
        ramp = min((year - policy_year) / 40, 1.0)
        weight = ramp * 0.5

        year_mask = counterfactual['year'] == year

        for col in share_cols:
            china_share = china_rows[col].values[0]
            indonesia_share = indonesia_rows[col].values[0]
            counterfactual.loc[year_mask, col] = (
                china_share * (1 - weight) + indonesia_share * weight
            )

        # Renormalize shares to sum to 1
        row_totals = counterfactual.loc[year_mask, share_cols].sum(axis=1)
        for col in share_cols:
            counterfactual.loc[year_mask, col] /= row_totals.values

    return counterfactual


def ssa_institutional_scenario(panel_df, model, feature_names):
    """
    Scenario: Sub-Saharan African countries adopt East Asian capital-account
    openness.

    For every year in which the East Asian reference group (KOR, SGP, MYS,
    THA, TWN) has a mean KAOPEN, each SSA country's ``kaopen`` for that year
    is replaced by that mean. Any ``Z_k_x_kaopen`` interaction listed in
    ``feature_names`` is then recomputed from the new ``kaopen``.

    Parameters
    ----------
    panel_df : DataFrame with iso3, year and (optionally) kaopen, Z_1..Z_3
    model : accepted for interface compatibility; not used here
    feature_names : list of model feature names, used only to decide which
                    interaction columns to rebuild

    Returns
    -------
    DataFrame of SSA rows under the scenario, or None when ``panel_df`` has
    no ``kaopen`` column. No prediction is made here; callers re-predict.
    """
    from .macro import SSA_COUNTRIES

    east_asian = ['KOR', 'SGP', 'MYS', 'THA', 'TWN']
    df = panel_df.copy()

    if 'kaopen' not in df.columns:
        return None

    # Year-by-year mean openness of the East Asian reference group.
    in_east_asia = df['iso3'].isin(east_asian)
    ea_kaopen = df.loc[in_east_asia].groupby('year')['kaopen'].mean()

    scenario = df[df['iso3'].isin(SSA_COUNTRIES)].copy()

    for year in scenario['year'].unique():
        if year not in ea_kaopen.index:
            continue  # no East Asian benchmark that year; keep actual kaopen
        scenario.loc[scenario['year'] == year, 'kaopen'] = ea_kaopen[year]

    # Keep demographic x openness interactions consistent with the new kaopen.
    for z in ('Z_1', 'Z_2', 'Z_3'):
        interaction = f'{z}_x_kaopen'
        if interaction in feature_names and z in scenario.columns:
            scenario[interaction] = scenario[z] * scenario['kaopen']

    return scenario


def japan_rate_normalization_scenario(panel_df, model, feature_names):
    """
    Scenario: Japan's interest rates normalize to roughly 2-3%.

    Returns a copy of Japan's panel rows with the rate variables shifted so
    the effect on carry incentives and predicted CA/GDP can be assessed by
    the caller.

    Assumed shock: the nominal rate moves from ~0% to 2.5%; with 2% inflation
    the real rate moves from ~-1% to ~0.5%, i.e. +1.5 pp on
    ``real_bond_10y_diff``. ``carry_vs_jpn``, when present, falls by 2.5
    (less carry advantage for other countries).

    Parameters
    ----------
    panel_df : DataFrame with iso3 and real_bond_10y_diff (carry_vs_jpn optional)
    model, feature_names : accepted for interface compatibility; not used here

    Returns
    -------
    Shifted copy of the JPN rows, or None (after printing a message) when
    ``real_bond_10y_diff`` is not a column of ``panel_df``.
    """
    scenario = panel_df[panel_df['iso3'] == 'JPN'].copy()

    if 'real_bond_10y_diff' not in scenario.columns:
        print("No rate differential data for Japan scenario")
        return None

    # Additive shifts applied to whichever of these columns are present.
    shifts = (
        ('real_bond_10y_diff', 1.5),
        ('carry_vs_jpn', -2.5),
    )
    for column, shift in shifts:
        if column in scenario.columns:
            scenario[column] = scenario[column] + shift

    return scenario


def decompose_residuals(model, panel_df, feature_names):
    """
    Split observed CA/GDP into model components:
    1. Demographic contribution  (sum of beta_k * Z_k over ``Z_*`` features)
    2. Macro fundamentals contribution (constant + remaining features)
    3. Unexplained residual (ca_gdp minus the two components)

    Parameters
    ----------
    model : fitted model exposing ``beta`` and ``constant``
    panel_df : DataFrame with iso3, year, the ``Z_*`` columns and, optionally,
               ca_gdp and the control columns
    feature_names : list of names aligned with model.beta; if its length
                    differs from len(model.beta) a warning is printed and the
                    list is cut to len(model.beta)

    Returns
    -------
    DataFrame with iso3, year, [ca_gdp], demo_component, macro_component and
    [unexplained]; the bracketed columns appear only when ``ca_gdp`` exists.
    Controls absent from ``panel_df`` are skipped and missing control values
    count as zero.
    """
    # Ensure feature_names matches model dimensions
    n_beta = len(model.beta)
    if len(feature_names) != n_beta:
        print(f"  Warning: feature_names ({len(feature_names)}) != model.beta ({n_beta})")
        feature_names = feature_names[:n_beta]

    # Partition (position, name) pairs into demographic and control terms.
    demo_terms = []
    control_terms = []
    for pos, feature in enumerate(feature_names):
        bucket = demo_terms if feature.startswith('Z_') else control_terms
        bucket.append((pos, feature))

    df = panel_df.copy()

    df['demo_component'] = sum(
        model.beta[pos] * df[feature] for pos, feature in demo_terms
    )

    # Macro part starts from the intercept and accumulates available controls.
    df['macro_component'] = model.constant
    for pos, feature in control_terms:
        if feature not in df.columns:
            continue
        df['macro_component'] += model.beta[pos] * df[feature].fillna(0)

    has_actual = 'ca_gdp' in df.columns
    if has_actual:
        df['unexplained'] = df['ca_gdp'] - df['demo_component'] - df['macro_component']

    wanted = ['iso3', 'year']
    if has_actual:
        wanted.append('ca_gdp')
    wanted += ['demo_component', 'macro_component', 'unexplained']
    return df[[col for col in wanted if col in df.columns]]


def generate_projection_table(profiles, countries=None, years_to_show=None):
    """
    Tabulate demographic CA contributions, one row per country and one
    column per year.

    Parameters
    ----------
    profiles : dict mapping country code -> DataFrame with ``year`` and
               ``demo_contribution`` columns
    countries : optional list of codes to include (codes missing from
                ``profiles`` are skipped); defaults to all profile keys
    years_to_show : optional list of years; defaults to every tenth year
                    from 2000 through 2060

    Returns
    -------
    DataFrame with a ``Country`` column followed by one column per year
    (named by ``str(year)``); NaN where a profile has no row for that year.
    """
    if years_to_show is None:
        years_to_show = [2000, 2010, 2020, 2030, 2040, 2050, 2060]

    if countries is None:
        countries = list(profiles.keys())

    def contribution_at(profile, year):
        # First matching row wins; absent years are reported as NaN.
        hits = profile[profile['year'] == year]
        if len(hits) > 0:
            return hits['demo_contribution'].values[0]
        return np.nan

    rows = [
        {
            'Country': code,
            **{str(year): contribution_at(profiles[code], year)
               for year in years_to_show},
        }
        for code in countries
        if code in profiles
    ]

    return pd.DataFrame(rows)


def _load_interaction_coefficients(coeff_path=None):
    """Read the KAOPEN level and Z×KAOPEN interaction coefficients from a CSV.

    The CSV must have ``variable`` and ``coefficient`` columns. When
    ``coeff_path`` is None the extended-plus-interactions regression table
    under OUTPUT_DIR / "tables" is used.

    Returns dict with keys: beta_kaopen, delta_1, delta_2, delta_3. A
    variable that is absent from the CSV yields 0 for its key.
    """
    if coeff_path is None:
        coeff_path = OUTPUT_DIR / "tables" / "regression_extended_plus_interactions.csv"

    table = pd.read_csv(coeff_path)
    by_variable = dict(zip(table['variable'], table['coefficient']))

    # Output key -> variable name as written in the regression table.
    wanted = (
        ('beta_kaopen', 'kaopen'),
        ('delta_1', 'Z_1_x_kaopen'),
        ('delta_2', 'Z_2_x_kaopen'),
        ('delta_3', 'Z_3_x_kaopen'),
    )
    return {key: by_variable.get(variable, 0) for key, variable in wanted}


def _marginal_effect(z1, z2, z3, coeffs):
    """Return dCA/dKAOPEN = beta_kaopen + delta_1*Z1 + delta_2*Z2 + delta_3*Z3.

    ``coeffs`` is a mapping with keys beta_kaopen, delta_1, delta_2, delta_3.
    """
    effect = coeffs['beta_kaopen']
    # Accumulate the interaction terms in order Z1, Z2, Z3.
    for key, z in (('delta_1', z1), ('delta_2', z2), ('delta_3', z3)):
        effect = effect + coeffs[key] * z
    return effect


def compute_openness_marginal_effects(panel_df, polys_df, coeff_path=None):
    """
    Deliverable A: Marginal effect of openness (dCA/dKAOPEN) for every country
    at current demographics and projected 2030/2040/2050 demographics.

    The effect is beta_kaopen + delta_1*Z_1 + delta_2*Z_2 + delta_3*Z_3 with
    coefficients read from the interaction-model CSV.

    Parameters
    ----------
    panel_df : DataFrame with iso3, year, Z_1, Z_2, Z_3 (``kaopen`` optional).
               "Current" demographics are the latest row per country with
               year <= 2024 and non-missing Z_1..Z_3.
    polys_df : DataFrame with iso3, year, Z_1, Z_2, Z_3 covering the
               projection years
    coeff_path : optional path to the coefficient CSV (see
                 ``_load_interaction_coefficients`` for the default)

    Returns
    -------
    DataFrame sorted by ``marginal_effect_current`` (descending) with iso3,
    current_year, current_kaopen, marginal_effect_current and one
    ``marginal_effect_<year>`` column per projection year (NaN when the
    country has no projection row for that year). When no country qualifies,
    an empty frame with these columns is returned and written.

    Output: output/tables/openness_marginal_effects.csv
    """
    projection_years = [2030, 2040, 2050]

    coeffs = _load_interaction_coefficients(coeff_path)
    print(f"  Interaction coefficients: {coeffs}")

    # Current demographics: latest historical year (≤2024) with Z data.
    # NOTE: GroupBy.last() returns the last non-null entry of each column, so
    # ``kaopen`` may come from an earlier year than ``year``/Z_* when the most
    # recent row has no kaopen value.
    hist = panel_df[panel_df['year'] <= 2024].dropna(subset=['Z_1', 'Z_2', 'Z_3'])
    current = (hist.sort_values('year')
               .groupby('iso3')
               .last()
               .reset_index())

    rows = []
    for _, r in current.iterrows():
        iso = r['iso3']
        row = {
            'iso3': iso,
            'current_year': int(r['year']),
            'current_kaopen': r.get('kaopen', np.nan),
            'marginal_effect_current': _marginal_effect(
                r['Z_1'], r['Z_2'], r['Z_3'], coeffs),
        }

        # Projected demographics from polys_df (first matching row is used).
        for proj_year in projection_years:
            proj = polys_df[(polys_df['iso3'] == iso) & (polys_df['year'] == proj_year)]
            if len(proj) > 0:
                p = proj.iloc[0]
                row[f'marginal_effect_{proj_year}'] = _marginal_effect(
                    p['Z_1'], p['Z_2'], p['Z_3'], coeffs)
            else:
                row[f'marginal_effect_{proj_year}'] = np.nan

        rows.append(row)

    # Explicit columns keep the frame well-formed when ``rows`` is empty;
    # without them the sort below raises KeyError on a column-less frame.
    result_columns = (
        ['iso3', 'current_year', 'current_kaopen', 'marginal_effect_current']
        + [f'marginal_effect_{proj_year}' for proj_year in projection_years]
    )
    result = pd.DataFrame(rows, columns=result_columns)
    result = result.sort_values('marginal_effect_current', ascending=False)
    result.to_csv(OUTPUT_DIR / "tables" / "openness_marginal_effects.csv", index=False)

    print(f"  Marginal effects computed for {len(result)} countries")
    print(f"  Top 5 (largest positive dCA/dKAOPEN):")
    for _, r in result.head(5).iterrows():
        print(f"    {r['iso3']}: {r['marginal_effect_current']:+.2f} pp per KAOPEN unit")
    print(f"  Bottom 5 (largest negative dCA/dKAOPEN):")
    for _, r in result.tail(5).iterrows():
        print(f"    {r['iso3']}: {r['marginal_effect_current']:+.2f} pp per KAOPEN unit")

    return result


def compute_openness_scenarios(panel_df, polys_df, coeff_path=None):
    """
    Deliverables B + C: Opening scenarios for EMs and closing scenarios for AEs.

    Opening: CA effect of raising KAOPEN to three targets computed from the
    latest observations — the mean of KOR/SGP/MYS/THA ('EA_avg'), the mean
    over countries with KAOPEN > 0 ('open_mean') and the sample maximum
    ('full'). A country already at or above a target gets a zero effect.
    Closing: CA effect of KAOPEN dropping by 0.5, 1.0, 1.5 units.

    Each effect is (marginal effect at the given demographics) × ΔKAOPEN,
    evaluated at current demographics (latest row with year <= 2024 and
    non-missing kaopen and Z_1) and at 2030/2040/2050 projections from
    ``polys_df``. Country/year combinations without data are skipped.

    Returns (opening_df, closing_df).
    Output: output/tables/opening_scenarios.csv, output/tables/closing_scenarios.csv
    """
    coeffs = _load_interaction_coefficients(coeff_path)

    # Latest historical KAOPEN per country (≤2024)
    observed = panel_df[panel_df['year'] <= 2024].dropna(subset=['kaopen', 'Z_1'])
    latest = observed.sort_values('year').groupby('iso3').last().reset_index()
    kaopen_by_country = dict(zip(latest['iso3'], latest['kaopen']))

    # KAOPEN targets for opening scenarios
    all_kaopen = latest['kaopen']
    kaopen_max = all_kaopen.max()
    kaopen_open_mean = all_kaopen[all_kaopen > 0].mean()  # financially open countries only
    ea_countries = ['KOR', 'SGP', 'MYS', 'THA']
    ea_kaopen = latest[latest['iso3'].isin(ea_countries)]['kaopen'].mean()

    print(f"  KAOPEN targets: EA avg={ea_kaopen:.2f}, open-country mean={kaopen_open_mean:.2f}, max={kaopen_max:.2f}")

    horizons = ['current', 2030, 2040, 2050]
    z_cols = ['Z_1', 'Z_2', 'Z_3']

    def effect_at(iso, horizon):
        """Return (year_label, marginal effect) or None when data is missing."""
        if horizon == 'current':
            source = latest[latest['iso3'] == iso]
        else:
            source = polys_df[(polys_df['iso3'] == iso) & (polys_df['year'] == horizon)]
        if len(source) == 0:
            return None
        first = source.iloc[0]
        z1, z2, z3 = first[z_cols]
        label = int(first['year']) if horizon == 'current' else horizon
        return label, _marginal_effect(z1, z2, z3, coeffs)

    # --- Opening scenarios (low-KAOPEN countries) ---
    opening_countries = ['IND', 'NGA', 'IDN', 'VNM', 'ETH', 'CHN', 'BGD',
                         'PHL', 'EGY', 'PAK', 'COL', 'ARG', 'BRA', 'TUR',
                         'KEN', 'GHA', 'TZA', 'UGA', 'MOZ', 'SEN']
    # Keep only countries that actually have a latest KAOPEN observation.
    opening_countries = [c for c in opening_countries if c in kaopen_by_country]
    opening_targets = [('EA_avg', ea_kaopen),
                       ('open_mean', kaopen_open_mean),
                       ('full', kaopen_max)]

    opening_rows = []
    for iso in opening_countries:
        current_kaopen = kaopen_by_country[iso]

        for horizon in horizons:
            located = effect_at(iso, horizon)
            if located is None:
                continue
            year_label, me = located

            for target_name, target_kaopen in opening_targets:
                delta_kaopen = target_kaopen - current_kaopen
                # No effect when the country is already at or above the target.
                delta_ca = 0.0 if delta_kaopen <= 0 else me * delta_kaopen

                opening_rows.append({
                    'iso3': iso,
                    'demographics_year': year_label,
                    'current_kaopen': current_kaopen,
                    'target': target_name,
                    'target_kaopen': target_kaopen,
                    'delta_kaopen': delta_kaopen,
                    'marginal_effect': me,
                    'delta_ca_gdp': delta_ca,
                })

    opening_df = pd.DataFrame(opening_rows)
    opening_df.to_csv(OUTPUT_DIR / "tables" / "opening_scenarios.csv", index=False)
    print(f"  Opening scenarios: {len(opening_df)} rows for {len(opening_countries)} countries")

    # --- Closing scenarios (high-KAOPEN countries) ---
    closing_countries = ['USA', 'JPN', 'DEU', 'KOR', 'GBR', 'FRA', 'CAN',
                         'AUS', 'NLD', 'CHE', 'SWE', 'NOR', 'ITA', 'ESP',
                         'SGP', 'NZL', 'IRL', 'AUT', 'BEL', 'DNK']
    closing_countries = [c for c in closing_countries if c in kaopen_by_country]
    closing_shocks = [('mild', -0.5), ('moderate', -1.0), ('severe', -1.5)]

    closing_rows = []
    for iso in closing_countries:
        current_kaopen = kaopen_by_country[iso]

        for horizon in horizons:
            located = effect_at(iso, horizon)
            if located is None:
                continue
            year_label, me = located

            for delta_name, delta_kaopen in closing_shocks:
                closing_rows.append({
                    'iso3': iso,
                    'demographics_year': year_label,
                    'current_kaopen': current_kaopen,
                    'scenario': delta_name,
                    'delta_kaopen': delta_kaopen,
                    'new_kaopen': current_kaopen + delta_kaopen,
                    'marginal_effect': me,
                    'delta_ca_gdp': me * delta_kaopen,
                })

    closing_df = pd.DataFrame(closing_rows)
    closing_df.to_csv(OUTPUT_DIR / "tables" / "closing_scenarios.csv", index=False)
    print(f"  Closing scenarios: {len(closing_df)} rows for {len(closing_countries)} countries")

    return opening_df, closing_df


def compute_global_efficiency(panel_df, polys_df, coeff_path=None):
    """
    Deliverable D: Global efficiency cost of deglobalization vs. full integration.

    Scenarios (evaluated at current and 2030/2040/2050 demographics):
    - Deglobalization: every country's KAOPEN moves to the sample median
    - Full integration: every country's KAOPEN moves to the sample max
    The median and max are taken over the latest observation per country
    (year <= 2024 with non-missing kaopen and Z_1).

    For each country the CA effect is (marginal effect) × (target − current
    KAOPEN). The summary reports gross (sum of |ΔCA|), net (signed sum) and
    mean absolute effects across countries; sums are unweighted.

    Parameters
    ----------
    panel_df : DataFrame with iso3, year, kaopen, Z_1, Z_2, Z_3
    polys_df : DataFrame with iso3, year, Z_1, Z_2, Z_3 for projection years
    coeff_path : optional path to the interaction-coefficient CSV

    Returns
    -------
    (detail, summary) : per-country rows and one summary row per horizon that
    has data. With no qualifying countries both frames are empty (columns
    only) rather than the function raising.

    Output: output/tables/global_openness_efficiency.csv and
    output/tables/global_openness_efficiency_detail.csv
    """
    horizons = ['current', 2030, 2040, 2050]
    coeffs = _load_interaction_coefficients(coeff_path)

    # Get latest historical data per country (≤2024)
    hist = panel_df[panel_df['year'] <= 2024].dropna(subset=['kaopen', 'Z_1'])
    latest = (hist.sort_values('year')
              .groupby('iso3')
              .last()
              .reset_index())

    kaopen_median = latest['kaopen'].median()
    kaopen_max = latest['kaopen'].max()

    print(f"  Global efficiency: KAOPEN median={kaopen_median:.2f}, max={kaopen_max:.2f}")
    print(f"  Countries in analysis: {len(latest)}")

    rows = []
    for proj_year in horizons:
        for _, r in latest.iterrows():
            iso = r['iso3']
            current_kaopen = r['kaopen']

            if proj_year == 'current':
                z1, z2, z3 = r['Z_1'], r['Z_2'], r['Z_3']
            else:
                proj = polys_df[(polys_df['iso3'] == iso) & (polys_df['year'] == proj_year)]
                if len(proj) == 0:
                    continue  # no projection for this country/year
                z1, z2, z3 = proj.iloc[0][['Z_1', 'Z_2', 'Z_3']]

            me = _marginal_effect(z1, z2, z3, coeffs)

            rows.append({
                'iso3': iso,
                'demographics_year': proj_year,
                'current_kaopen': current_kaopen,
                'marginal_effect': me,
                # Deglobalization: move to median
                'delta_ca_deglobalization': me * (kaopen_median - current_kaopen),
                # Full integration: move to max
                'delta_ca_full_integration': me * (kaopen_max - current_kaopen),
            })

    # Explicit columns keep the filters below valid when ``rows`` is empty;
    # a column-less frame would raise KeyError on 'demographics_year'.
    detail = pd.DataFrame(rows, columns=[
        'iso3', 'demographics_year', 'current_kaopen', 'marginal_effect',
        'delta_ca_deglobalization', 'delta_ca_full_integration',
    ])

    # Summary table: one row per horizon that produced any country rows.
    summary_rows = []
    for year_val in horizons:
        subset = detail[detail['demographics_year'] == year_val]
        if len(subset) == 0:
            continue

        deglob = subset['delta_ca_deglobalization']
        full = subset['delta_ca_full_integration']

        summary_rows.append({
            'scenario_year': year_val,
            'n_countries': len(subset),
            'gross_reallocation_deglobalization': deglob.abs().sum(),
            'net_reallocation_deglobalization': deglob.sum(),
            'gross_reallocation_full_integration': full.abs().sum(),
            'net_reallocation_full_integration': full.sum(),
            'mean_abs_effect_deglobalization': deglob.abs().mean(),
            'mean_abs_effect_full_integration': full.abs().mean(),
        })

    summary = pd.DataFrame(summary_rows, columns=[
        'scenario_year', 'n_countries',
        'gross_reallocation_deglobalization', 'net_reallocation_deglobalization',
        'gross_reallocation_full_integration', 'net_reallocation_full_integration',
        'mean_abs_effect_deglobalization', 'mean_abs_effect_full_integration',
    ])

    # Save both detail and summary
    detail.to_csv(OUTPUT_DIR / "tables" / "global_openness_efficiency_detail.csv", index=False)
    summary.to_csv(OUTPUT_DIR / "tables" / "global_openness_efficiency.csv", index=False)

    print(f"\n  Global efficiency summary:")
    for _, s in summary.iterrows():
        print(f"    {s['scenario_year']}: "
              f"Deglobalization gross={s['gross_reallocation_deglobalization']:.1f} pp, "
              f"Full integration gross={s['gross_reallocation_full_integration']:.1f} pp")

    return detail, summary


def compute_ge_clearing(panel_df, polys_df, baseline_model=None, feature_names=None):
    """
    General equilibrium capital market clearing overlay (Item 4.7A).

    In partial equilibrium (PE), projected demographic CAs don't sum to zero
    globally. This function:
    1. Takes the baseline Z coefficients (from ``baseline_model`` or, failing
       that, from the saved baseline regression CSV)
    2. Estimates S1, a PanelGLS regression of ``real_bond_10y_diff`` on the Z
       variables plus whichever controls are present, to obtain each
       country's demographic pressure on bond yields
    3. Builds GDP weights from the latest year that has ``ngdp_usd`` data
    4. Projects each country's PE demographic CA and yield effect every five
       years from 2000 through 2060
    5. Solves for the world interest rate adjustment that clears the
       GDP-weighted market, capped at ±2pp
    6. Produces GE-adjusted CA projections

    The clearing condition: sum_i [w_i * (CA_i + delta * (yield_i - Δr_world))] = 0
    where delta = 0.127 (Model 3b coefficient). The per-country value stored
    in ``ge_adjustment`` is delta * (yield_i - Δr_world).

    Parameters
    ----------
    panel_df : DataFrame; the code reads iso3, year, ca_gdp,
               real_bond_10y_diff, ngdp_usd, the Z columns and any of the S1
               controls that are present
    polys_df : DataFrame with iso3, year, Z_1, Z_2, Z_3 used for projection
    baseline_model : optional fitted model exposing ``beta``; only used when
                     ``feature_names`` is also supplied
    feature_names : optional list of names aligned with baseline_model.beta

    Returns
    -------
    (proj, clearing) : country-year PE/GE projections and the per-year
    clearing summary; ``(None, None)`` when no baseline coefficients are
    available or S1 has fewer than 100 usable rows.

    Side effects: prints a report and writes
    output/tables/ge_clearing_projections.csv and
    output/tables/ge_clearing_rates.csv.
    """
    from .model import PanelGLS

    print("\n" + "=" * 70)
    print("GENERAL EQUILIBRIUM CLEARING RATE OVERLAY")
    print("=" * 70)

    # --- Step 1: Get baseline demographic coefficients ---
    # If model provided, use it; otherwise load from CSV
    if baseline_model is not None and feature_names is not None:
        z_indices = [i for i, n in enumerate(feature_names) if n.startswith('Z_')]
        z_names = [feature_names[i] for i in z_indices]
        z_betas = {z_names[i]: baseline_model.beta[z_indices[i]] for i in range(len(z_names))}
    else:
        # Load baseline coefficients from saved CSV
        coeff_path = OUTPUT_DIR / "tables" / "regression_baseline_demo_plus_eba.csv"
        if coeff_path.exists():
            coeffs = pd.read_csv(coeff_path)
            coeff_map = dict(zip(coeffs['variable'], coeffs['coefficient']))
            # Missing Z entries in the CSV default to a zero coefficient.
            z_betas = {f'Z_{p}': coeff_map.get(f'Z_{p}', 0) for p in [1, 2, 3]}
        else:
            print("  ERROR: No baseline model coefficients available")
            return None, None

    # From here on the three polynomial terms are assumed, whichever branch ran.
    # NOTE(review): if feature_names lacks any of Z_1..Z_3, the z_betas lookup
    # in Step 4 raises KeyError.
    z_names = ['Z_1', 'Z_2', 'Z_3']
    print(f"  Baseline Z coefficients: {z_betas}")

    # --- Step 2: Estimate S1 (demographics → bond yields) ---
    # Need estimation sample for S1: rows with observed CA/GDP in 1986-2024
    est = panel_df[
        (panel_df['ca_gdp'].notna()) &
        (panel_df['year'] >= 1986) &
        (panel_df['year'] <= 2024)
    ].copy()

    # Regressors are limited to columns actually present in the panel.
    s1_vars = z_names + ['fiscal_bal_gdp', 'expected_growth', 'nfa_gdp_lag', 'log_rel_opw']
    s1_vars = [v for v in s1_vars if v in est.columns]
    s1_df = est.dropna(subset=['real_bond_10y_diff'] + s1_vars)

    if len(s1_df) < 100:
        print("  Insufficient data for S1 estimation")
        return None, None

    s1 = PanelGLS()
    s1.fit(s1_df['real_bond_10y_diff'].values, s1_df[s1_vars].values,
           s1_df['iso3'].values, s1_df['year'].values)
    # Keep only the Z coefficients; assumes s1.beta follows the column order of
    # s1_vars — TODO confirm against PanelGLS.
    s1_z_betas = {s1_vars[i]: s1.beta[i] for i in range(len(s1_vars)) if s1_vars[i] in z_names}
    print(f"  S1 Z coefficients (Z → yield diff): {s1_z_betas}")

    # Model 3b rate-to-CA coefficient
    delta_rate = 0.127
    print(f"  Rate-to-CA coefficient (Model 3b): {delta_rate}")

    # --- Step 3: Get GDP weights ---
    gdp_data = panel_df[['iso3', 'year', 'ngdp_usd']].dropna()
    # Use latest available GDP for weighting; countries without GDP in that
    # year get no weight and are skipped in Step 4.
    latest_gdp_year = gdp_data['year'].max()
    gdp_weights = gdp_data[gdp_data['year'] == latest_gdp_year][['iso3', 'ngdp_usd']].copy()
    total_gdp = gdp_weights['ngdp_usd'].sum()
    gdp_weights['weight'] = gdp_weights['ngdp_usd'] / total_gdp
    weight_map = dict(zip(gdp_weights['iso3'], gdp_weights['weight']))
    print(f"  GDP weights from {latest_gdp_year}, {len(weight_map)} countries, "
          f"top 5: {sorted(weight_map.items(), key=lambda x: -x[1])[:5]}")

    # --- Step 4: Project PE demographic CAs and yield effects ---
    # Five-year grid: 2000, 2005, ..., 2060
    proj_years = list(range(2000, 2065, 5))
    countries = polys_df['iso3'].unique()

    rows = []
    for iso3 in countries:
        w = weight_map.get(iso3, 0)
        if w == 0:
            continue

        cdf = polys_df[polys_df['iso3'] == iso3]
        for year in proj_years:
            yr = cdf[cdf['year'] == year]
            if len(yr) == 0:
                continue

            # Demographic CA contribution (PE)
            demo_ca = sum(z_betas[zv] * yr[zv].values[0] for zv in z_names)

            # Demographic yield pressure
            demo_yield = sum(s1_z_betas[zv] * yr[zv].values[0] for zv in z_names)

            rows.append({
                'iso3': iso3,
                'year': year,
                'weight': w,
                'demo_ca_pe': demo_ca,
                'demo_yield_effect': demo_yield,
            })

    proj = pd.DataFrame(rows)

    # --- Step 5: Compute global PE imbalance and clearing rate ---
    clearing_rows = []
    for year in proj_years:
        yr_data = proj[proj['year'] == year]
        if len(yr_data) == 0:
            continue

        # Renormalize weights to sum to 1 for available countries
        w_sum = yr_data['weight'].sum()
        yr_data = yr_data.copy()
        yr_data['w_norm'] = yr_data['weight'] / w_sum

        # GDP-weighted global PE CA imbalance
        pe_imbalance = (yr_data['w_norm'] * yr_data['demo_ca_pe']).sum()

        # GDP-weighted global yield effect
        global_yield_effect = (yr_data['w_norm'] * yr_data['demo_yield_effect']).sum()

        # Clearing rate adjustment:
        # sum_i [w_i * (demo_ca_i + delta * (yield_i - Δr))] = 0
        # pe_imbalance + delta * (global_yield_effect - Δr) = 0
        # Δr = pe_imbalance/delta + global_yield_effect
        delta_r_uncapped = pe_imbalance / delta_rate + global_yield_effect

        # Cap at ±2pp — the rate channel alone cannot clear large imbalances.
        # The residual must be absorbed by other channels (FX, fiscal, structural).
        max_delta_r = 2.0
        delta_r_world = np.clip(delta_r_uncapped, -max_delta_r, max_delta_r)
        # Zero (up to rounding) when the cap does not bind.
        residual_imbalance = pe_imbalance + delta_rate * (global_yield_effect - delta_r_world)

        clearing_rows.append({
            'year': year,
            'pe_global_imbalance': pe_imbalance,
            'global_yield_effect': global_yield_effect,
            'delta_r_uncapped': delta_r_uncapped,
            'delta_r_world': delta_r_world,
            'residual_imbalance': residual_imbalance,
            # Denominator is floored at 0.001 to avoid dividing by ~0.
            'pct_cleared_by_rates': (1 - abs(residual_imbalance) / max(abs(pe_imbalance), 0.001)) * 100,
            'n_countries': len(yr_data),
            'weight_coverage': w_sum,
        })

    clearing = pd.DataFrame(clearing_rows)

    # --- Step 6: Compute GE-adjusted CAs ---
    delta_r_map = dict(zip(clearing['year'], clearing['delta_r_world']))

    proj['delta_r_world'] = proj['year'].map(delta_r_map)
    # rate_channel_pe is the yield effect before the world-rate offset;
    # rate_channel_ge nets out Δr_world and is what enters demo_ca_ge.
    proj['rate_channel_pe'] = delta_rate * proj['demo_yield_effect']
    proj['rate_channel_ge'] = delta_rate * (proj['demo_yield_effect'] - proj['delta_r_world'])
    proj['demo_ca_ge'] = proj['demo_ca_pe'] + proj['rate_channel_ge']
    # Equals rate_channel_ge by construction.
    proj['ge_adjustment'] = proj['demo_ca_ge'] - proj['demo_ca_pe']

    # --- Print summary ---
    print("\n  Clearing rate adjustment by year:")
    print(f"  {'Year':>6} {'PE Imbalance':>14} {'Δr* (pp)':>10} {'Countries':>10}")
    for _, r in clearing.iterrows():
        print(f"  {int(r['year']):>6} {r['pe_global_imbalance']:>14.3f} "
              f"{r['delta_r_world']:>10.3f} {int(r['n_countries']):>10}")

    # Focus country comparison
    focus = ['JPN', 'DEU', 'USA', 'KOR', 'CHN', 'IND', 'IDN', 'NGA', 'BRA']
    print("\n  PE vs GE demographic CA contributions (selected countries):")
    print(f"  {'Country':>8} {'':>6} {'2020':>8} {'2030':>8} {'2040':>8} {'2050':>8} {'2060':>8}")
    for iso3 in focus:
        cdf = proj[proj['iso3'] == iso3]
        if len(cdf) == 0:
            continue
        pe_vals = {int(r['year']): r['demo_ca_pe'] for _, r in cdf.iterrows()}
        ge_vals = {int(r['year']): r['demo_ca_ge'] for _, r in cdf.iterrows()}
        # Years without a projection row are printed as 0.00, not blank.
        pe_str = ''.join(f"{pe_vals.get(y, 0):>8.2f}" for y in [2020, 2030, 2040, 2050, 2060])
        ge_str = ''.join(f"{ge_vals.get(y, 0):>8.2f}" for y in [2020, 2030, 2040, 2050, 2060])
        print(f"  {iso3:>8} {'PE':>6} {pe_str}")
        print(f"  {'':>8} {'GE':>6} {ge_str}")

    # Save
    proj.to_csv(OUTPUT_DIR / "tables" / "ge_clearing_projections.csv", index=False)
    clearing.to_csv(OUTPUT_DIR / "tables" / "ge_clearing_rates.csv", index=False)
    print(f"\n  Saved to output/tables/ge_clearing_projections.csv")

    return proj, clearing


# Running the file directly only confirms that it loads; no scenario is
# executed. The package-relative imports live inside the functions, so they
# are not triggered here.
if __name__ == "__main__":
    print("Scenario module loaded successfully")
